#!/usr/bin/env python3

import numpy as np
import numpy_extra as npe
import rospy
import sys
import threading

from ipm import IPM

from cv_bridge import *
from transformations import euler_from_quaternion, quaternion_from_euler
from sensor_msgs.msg import *
from std_msgs.msg import *
from nav_msgs.msg import *
from geometry_msgs.msg import *

# debug mode is switched on by passing --debug anywhere on the command line
DEBUG = any(arg == "--debug" for arg in sys.argv)

if DEBUG:
    # OpenCV is only needed to display the map window (imshow / waitKey)
    import cv2

class LocalMapperNode(object):
    """Accumulate semantic camera labels into a local top-down voting grid.

    Every incoming label image is projected onto the ground plane using the
    per-pixel polar coordinates (r, theta) supplied by ``IPM``. Each projected
    pixel casts a weighted vote for its category in ``local_map``; the winning
    category per cell is kept in ``local_map_flat`` and published as an
    ``OccupancyGrid`` on "semantic_grid" and as a color image on
    "semantic_image".

    The vehicle pose is dead-reckoned in ``spin`` from the speed and IMU yaw
    callbacks, and the map is periodically shifted so that the vehicle sits
    40% of the map extent behind the map center relative to its heading.
    """

    def __init__(self,
      node_name = "local_mapper_node",
      ns_camera = "/camera_front",
      ns_imu = "/imu",
      ns_vehicle = "/vehicle",
      pixels_per_meter = 8.0,
      configuration = "camera_ethernet_imx291",
    ):
        """Initialise the ROS node, its map state, publishers and subscribers.

        Blocks until the parameter "<ns_camera>/semantic_categories" exists.

        Args:
            node_name: name passed to ``rospy.init_node``.
            ns_camera: namespace of the "<ns>/semantic" topic and of the
                "<ns>/semantic_categories" parameter.
            ns_imu: namespace of the "<ns>/data" Imu topic.
            ns_vehicle: namespace of the "<ns>/speed" and "<ns>/odom" topics.
            pixels_per_meter: resolution of the local map.
            configuration: camera configuration name forwarded to ``IPM``.
        """
        self.is_initialized = False
        rospy.init_node(node_name)

        # parameters

        self.pixels_per_meter = pixels_per_meter

        # publishers

        self.pub_grid = rospy.Publisher("semantic_grid", OccupancyGrid, queue_size = 1)
        self.pub_image = rospy.Publisher("semantic_image", Image, queue_size = 1)

        # inverse projection of the camera
        # i.e. r and theta for each pixel in camera image
        # will be populated by ipm on the first semantic image
        self.ground_plane_r = None
        self.ground_plane_theta = None

        # fetch semantic category information
        while not rospy.has_param("%s/semantic_categories" % ns_camera):
            rospy.loginfo_throttle(10, "waiting for categories")
            rospy.sleep(1)
        self.categories = rospy.get_param("%s/semantic_categories" % ns_camera)

        # give less weight to background class
        self.semantic_weights = 2 * np.ones(len(self.categories), dtype = np.uint16)
        self.semantic_weights[0] = 1

        # color map for visualization
        # one color for each category plus the final [255,255,255] is a fake
        # category to display white-colored overlays
        # (channel order is reversed; presumably RGB -> BGR to match the
        # "bgr8" encoding used when publishing -- confirm against the param)
        self.color_map = np.array(
            [list(reversed(c["color"])) for c in self.categories] + [[255,255,255]],
            dtype=np.uint8)

        # local map with votes per category at each spatial coordinate
        self.local_map = np.zeros((1024, 1024, len(self.categories)))

        # argmaxed version of above
        self.local_map_flat = np.zeros((1024, 1024), dtype = np.uint8)

        # current position of the car in the map in meters
        # (0, 0) is the center of the image
        self.cur_y = 0.0
        self.cur_x = 0.0
        self.cur_yaw = 0.0
        self.velocity_x = 0.0

        # ring buffer of (x, y, yaw), one slot per 10 ms of the current second
        self.poses = np.zeros((100, 3), dtype = np.float64)

        # deals with inverse projection of the camera
        self.ipm = IPM(pixels_per_meter = self.pixels_per_meter, configuration = configuration)

        # thread locking to avoid callbacks interrupting things that should be atomic
        self.lock = threading.Lock()

        # subscribers
        # registered only after all state above exists, so an early callback
        # can neither see a half-built object nor have its update overwritten

        self.sub_semantic = rospy.Subscriber(
            "%s/semantic" % ns_camera, Image, self.on_semantic, queue_size = 1)
        self.sub_imu_data = rospy.Subscriber(
            "%s/data" % ns_imu, Imu, self.on_imu_data, queue_size = 1)
        self.sub_speed = rospy.Subscriber(
            "%s/speed" % ns_vehicle, Float32, self.on_speed, queue_size = 1)

        self.sub_odom = rospy.Subscriber(
            "%s/odom" % ns_vehicle, Odometry, self.on_odom, queue_size = 1)

        # we are ready
        self.is_initialized = True

    def on_odom(self, msg):
        """Odometry callback; the message is currently ignored."""
        pass

    def on_speed(self, msg):
        """Store the vehicle speed, converting ``msg.data`` from km/h to m/s."""
        self.velocity_x = msg.data / 3.6 # km/h to m/s

    def on_imu_data(self, msg):
        """Update ``cur_yaw`` from the IMU orientation quaternion."""
        # quaternion is passed in (w, x, y, z) order; roll and pitch are discarded
        dummy_roll, dummy_pitch, self.cur_yaw = euler_from_quaternion((
            msg.orientation.w,
            msg.orientation.x,
            msg.orientation.y,
            msg.orientation.z))

    def on_semantic(self, msg):
        """Integrate one semantic label image into the map and publish outputs.

        Only the image rows between 40% and 80% of the image height are used.
        The pose used for the projection is looked up in the pose ring buffer
        from the nanosecond part of the message stamp.
        """
        if not self.is_initialized:
            return

        # "with" guarantees the lock is released even if something below raises
        with self.lock:
            labels = imgmsg_to_cv2(msg)

            if self.ground_plane_r is None or self.ground_plane_theta is None:
                self.ground_plane_r, self.ground_plane_theta = \
                    self.ipm.get_ground_plane(labels.shape)

                self.cutoff_top = int(labels.shape[0] * 0.4)
                self.cutoff_bottom = int(labels.shape[0] * 0.8)

            # pose at capture time: one ring buffer slot per 10 ms
            then_index = int(msg.header.stamp.nsecs / 1e7)
            then_x, then_y, then_yaw = self.poses[then_index]

            band = slice(self.cutoff_top, self.cutoff_bottom)
            self._accumulate_votes(
                labels[band, :],
                self.ground_plane_r[band, :],
                self.ground_plane_theta[band, :],
                then_x, then_y, then_yaw)

            if self.pub_grid.get_num_connections() > 0:
                self._publish_grid()

            want_image = self.pub_image.get_num_connections() > 0
            if want_image or DEBUG:
                # rows are flipped vertically for display
                colored = self.color_map[self.local_map_flat[::-1, :]]

                if want_image:
                    self.pub_image.publish(cv2_to_imgmsg(colored, encoding = "bgr8"))

                if DEBUG:
                    cv2.imshow('local_map', colored)
                    cv2.waitKey(1)

    def _accumulate_votes(self, labels, ground_plane_r, ground_plane_theta,
                          pose_x, pose_y, pose_yaw):
        """Project labelled pixels into the map and add their category votes.

        Args:
            labels: 2-D integer array of category indices.
            ground_plane_r: per-pixel ground distance, same shape as labels.
            ground_plane_theta: per-pixel ground bearing, same shape as labels.
            pose_x, pose_y: vehicle position in map meters.
            pose_yaw: vehicle heading in radians.

        Returns:
            True if at least one pixel landed on the map, False otherwise
            (in which case the map is left untouched).
        """
        height, width = self.local_map.shape[:2]

        # non-finite ground plane entries are expected and filtered below
        with np.errstate(invalid = "ignore"):
            # polar coordinates around the vehicle -> map space, in meters
            bearing = ground_plane_theta + pose_yaw
            map_y = pose_y + ground_plane_r * np.sin(bearing)
            map_x = pose_x + ground_plane_r * np.cos(bearing)

            # map space -> fractional pixel coordinates on the local grid
            pix_y = map_y * self.pixels_per_meter + height / 2.0
            pix_x = map_x * self.pixels_per_meter + width / 2.0

            # filter while the coordinates are still floats: NaN fails every
            # comparison and huge values cannot wrap around into the grid,
            # which an early cast to a small integer type would allow
            where_valid = (ground_plane_r > 0) & \
                (pix_y >= 0) & (pix_x >= 0) & \
                (pix_y < height) & (pix_x < width)

        if not np.any(where_valid):
            return False

        rows = pix_y[where_valid].astype(np.intp)
        cols = pix_x[where_valid].astype(np.intp)
        labels = labels[where_valid]

        # np.add.at accumulates correctly when several pixels hit the same cell
        np.add.at(self.local_map, (rows, cols, labels), self.semantic_weights[labels])

        # partial argmax on only the region that has changed,
        # so we aren't argmaxing a megapixel each time
        # (+ 1 because slice ends are exclusive)
        y0, y1 = rows.min(), rows.max() + 1
        x0, x1 = cols.min(), cols.max() + 1
        self.local_map_flat[y0:y1, x0:x1] = np.argmax(
            self.local_map[y0:y1, x0:x1], axis = 2)
        return True

    def _publish_grid(self):
        """Publish ``local_map_flat`` as an OccupancyGrid on ``pub_grid``."""
        rows, cols = self.local_map.shape[:2]

        msg_grid = OccupancyGrid()
        msg_grid.header.frame_id = "base_link"
        msg_grid.info.map_load_time = rospy.get_rostime()
        msg_grid.info.resolution = float(1.0/self.pixels_per_meter)
        msg_grid.info.width = int(cols)
        msg_grid.info.height = int(rows)
        # NOTE(review): origin is the current pose plus half the map extent;
        # confirm this sign convention against the consumers of the grid
        msg_grid.info.origin.position.x = float(self.cur_x + cols/self.pixels_per_meter/2.0)
        msg_grid.info.origin.position.y = float(self.cur_y + rows/self.pixels_per_meter/2.0)
        q = quaternion_from_euler(0.0, 0.0, self.cur_yaw)
        msg_grid.info.origin.orientation.w = q[0]
        msg_grid.info.origin.orientation.x = q[1]
        msg_grid.info.origin.orientation.y = q[2]
        msg_grid.info.origin.orientation.z = q[3]
        # tobytes() replaces tostring(), which newer NumPy releases removed
        msg_grid.data = np.ascontiguousarray(self.local_map_flat).tobytes()
        self.pub_grid.publish(msg_grid)

    def shift_map(self, shift_x, shift_y):
        """Shift ``local_map`` and ``local_map_flat`` and move the poses along.

        Args:
            shift_x, shift_y: requested shift in meters. The shift actually
                applied is truncated to a whole number of pixels, and the
                current pose and pose history are moved by exactly that
                amount so raster and pose stay aligned.
        """
        with self.lock:
            shift_px = int(shift_x * self.pixels_per_meter)
            shift_py = int(shift_y * self.pixels_per_meter)

            if shift_px == 0 and shift_py == 0:
                # less than one pixel: nothing to move
                return

            # distance the raster really moves, in meters
            moved_x = shift_px / float(self.pixels_per_meter)
            moved_y = shift_py / float(self.pixels_per_meter)

            self.cur_x -= moved_x
            self.cur_y -= moved_y
            self.poses[:,0] -= moved_x
            self.poses[:,1] -= moved_y
            self.local_map = npe.shift_2d_replace(self.local_map, -shift_px, -shift_py, 0)
            self.local_map_flat = npe.shift_2d_replace(self.local_map_flat, -shift_px, -shift_py, 0)

    def spin(self):
        """Run the 50 Hz dead-reckoning loop until ROS shuts down.

        Each cycle integrates the speed along the current yaw, writes the pose
        into the ring buffer slots for the next 50 ms, and every 100th cycle
        recenters the map around the vehicle.
        """
        #rospy.spin()
        rate = rospy.Rate(50)
        seq = 0
        last_t = rospy.Time.now()
        while not rospy.is_shutdown():
            seq += 1
            rate.sleep()

            t = rospy.Time.now()
            # integrate over the time that really elapsed instead of assuming
            # a fixed period; never step backwards
            dt = max((t - last_t).to_sec(), 0.0)
            last_t = t

            if not self.is_initialized:
                continue

            # snapshot: the IMU callback may change cur_yaw at any time
            yaw = self.cur_yaw

            self.cur_x += self.velocity_x * dt * np.cos(yaw)
            self.cur_y += self.velocity_x * dt * np.sin(yaw)

            # fill the current 10 ms slot and the four after it
            indexes = (int(t.nsecs / 1e7) + np.arange(5)) % 100
            self.poses[indexes, 0] = self.cur_x
            self.poses[indexes, 1] = self.cur_y
            self.poses[indexes, 2] = yaw

            if seq % 100 != 0:
                continue

            # desired vehicle position: 40% of the map extent behind the
            # center, opposite to the heading
            target_x = -np.cos(yaw) * self.local_map.shape[1] / self.pixels_per_meter * 0.4
            target_y = -np.sin(yaw) * self.local_map.shape[0] / self.pixels_per_meter * 0.4
            offset_x = self.cur_x - target_x
            offset_y = self.cur_y - target_y

            # recenter when either axis is off target
            if offset_x != 0.0 or offset_y != 0.0:
                self.shift_map(offset_x, offset_y)


if __name__ == "__main__":
    # build the node with its default arguments and run its loop until shutdown
    LocalMapperNode().spin()

